#!/usr/bin/python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#   http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import re
import sys
import os
import subprocess
from threading import Timer
from xml.dom.minidom import parse
from cloudutils.configFileOps import configFileOps
from cloudutils.networkConfig import networkConfig

# Hook activity is appended to a dedicated file next to libvirt's own logs.
# Only the time of day (no date) is stamped on each record, followed by the
# millisecond part, the logger name, the level and the message.
logging.basicConfig(
    level=logging.INFO,
    filename='/var/log/libvirt/qemu-hook.log',
    filemode='a',
    datefmt='%H:%M:%S',
    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
)

# Single named logger shared by every function in this script.
logger = logging.getLogger('qemu-hook')

# Directory scanned for user-provided hook scripts.
customDir = "/etc/libvirt/hooks/custom"

# Permission bits associated with customDir (rwxr--r--).
# NOTE(review): not referenced anywhere in the visible part of this file.
customDirPermissions = 0o744

# Maximum run time granted to one custom script: 10 minutes, in seconds.
timeoutSeconds = 10 * 60

# Operations (argv[2]) this hook accepts; anything else is rejected in the
# __main__ section.
validQemuActions = [
    'prepare',
    'start',
    'started',
    'stopped',
    'release',
    'migrate',
    'restore',
    'reconnect',
    'attach',
]

def isOldStyleBridge(brName):
    """Return True when brName begins with the legacy "cloudVirBr" prefix."""
    return brName.startswith("cloudVirBr")

def isNewStyleBridge(brName):
    """Return True when brName looks like "br<device>-<digits>".

    Names starting with "brvx-" are always rejected, even though they would
    otherwise satisfy the pattern. The pattern is only anchored at the start
    of the name, so trailing characters after the digits are tolerated.
    """
    return (not brName.startswith('brvx-')
            and re.match(r"br(\w+)-(\d+)", brName) is not None)

def getGuestNetworkDevice():
    """Return the base name of the device enslaved to the guest network device.

    Reads the "guest.network.device" entry from the CloudStack agent
    properties file, passes it to networkConfig.getEnslavedDev(..., 1) and
    returns the part of the result that precedes the first "." (so a name
    such as "eth0.100" yields "eth0").

    NOTE(review): getEnslavedDev is presumably able to return a non-string
    when nothing is enslaved, in which case this raises - confirm against
    cloudutils.
    """
    netlib = networkConfig()
    agentProperties = configFileOps("/etc/cloudstack/agent/agent.properties")
    guestDevice = agentProperties.getEntry("guest.network.device")
    enslaved = netlib.getEnslavedDev(guestDevice, 1)
    baseName, _, _ = enslaved.partition(".")
    return baseName

def handleMigrateBegin():
    """Filter the migrating domain XML so its bridges match this host.

    The domain XML is read from stdin. For every <interface> that has a
    <source> child, the source's "bridge" attribute is inspected:

    * old style ("cloudVirBr<vlan>"): the VLAN id is what remains after
      removing "cloudVirBr"; the device comes from getGuestNetworkDevice().
    * new style ("br<device>-<vlan>"): VLAN id and device are cut out of the
      name; if networkConfig.isNetworkDev() rejects that device, the one from
      getGuestNetworkDevice() is used instead.
    * anything else is left untouched.

    Matching bridges are renamed to "br<device>-<vlan>" and the whole
    document is printed on stdout.

    This is best effort: on any error nothing is printed (the libvirt hooks
    documentation states that empty output means "XML unchanged") and the
    failure is written to the hook log instead of being silently dropped.
    """
    try:
        domain = parse(sys.stdin)
        for interface in domain.getElementsByTagName("interface"):
            sources = interface.getElementsByTagName("source")
            if sources.length == 0:
                continue

            source = sources[0]
            bridge = source.getAttribute("bridge")
            if isOldStyleBridge(bridge):
                vlanId = bridge.replace("cloudVirBr", "")
                phyDev = getGuestNetworkDevice()
            elif isNewStyleBridge(bridge):
                vlanId = re.sub(r"br(\w+)-", "", bridge)
                # Strip the leading "br" and the trailing "-<vlan>" to get
                # the device name embedded in the bridge name.
                phyDev = re.sub(r"-(\d+)$", "", re.sub(r"^br", "", bridge))
                netlib = networkConfig()
                if not netlib.isNetworkDev(phyDev):
                    phyDev = getGuestNetworkDevice()
            else:
                # Not a bridge naming scheme handled here.
                continue

            source.setAttribute("bridge", "br" + phyDev + "-" + vlanId)

        # Printed only after every interface was processed, so a failure
        # half-way never emits partially rewritten XML.
        print(domain.toxml())
    except Exception:
        # Do not abort the hook and do not write to stdout; just make the
        # failure visible in the log.
        logger.exception("Failed to rewrite bridge names in the migrating domain XML; "
                         "no modified XML was emitted")


def executeCustomScripts(sysArgs):
    """Run each selected custom script with sysArgs as its arguments.

    Does nothing when customDir is missing or is not a directory. Scripts
    are chosen and ordered by getCustomScriptsFromDirectory() and run one
    after another through executeScript().
    """
    # isdir() is already False for a path that does not exist.
    if not os.path.isdir(customDir):
        return

    for hookScript in getCustomScriptsFromDirectory():
        executeScript(hookScript, sysArgs)


def executeScript(scriptName, sysArgs):
    """Run one script from customDir and log how it ended.

    Args:
        scriptName: file name of the script inside customDir.
        sysArgs: sequence of arguments passed to the script unchanged.

    The script is skipped with a warning when it is not executable. While it
    runs, a timer calls terminateProcess() after timeoutSeconds. The outcome
    (success with stdout, non-zero exit with stderr, timeout, or a failure
    to launch) is written to the hook log; nothing is returned or raised.
    """
    logger.info('Executing custom script: %s, parameters: %s' % (scriptName, ' '.join(map(str, sysArgs))))
    path = os.path.join(customDir, scriptName)

    if not os.access(path, os.X_OK):
        logger.warning('Custom script: %s is not executable; skipping execution.' % scriptName)
        return

    try:
        process = subprocess.Popen([path] + list(sysArgs), stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE, shell=False)
        # Watchdog: created before the try/finally below so that the
        # cancel() in `finally` can never hit an unbound name.
        timer = Timer(timeoutSeconds, terminateProcess, [process, scriptName])
        timer.start()
        try:
            output, error = process.communicate()
        finally:
            timer.cancel()
    except Exception as e:
        # Typically an OSError from Popen (e.g. the file cannot be executed).
        logger.exception("Custom script: %s finished with error: \n%s" % (scriptName, e))
        return

    # Popen reports "killed by signal N" as return code -N; 15 is SIGTERM,
    # which is what process.terminate() sends on POSIX.
    if process.returncode == -15:
        logger.error('Custom script: %s terminated after timeout of %s second[s].'
                     % (scriptName, timeoutSeconds))
        return

    if process.returncode != 0:
        logger.info('return code: %s' % str(process.returncode))
        # The pipes deliver bytes; decode so the log shows the real text
        # instead of a b'...' repr with escaped newlines.
        logger.error("Custom script: %s finished with error: \n%s"
                     % (scriptName, error.decode('utf-8', errors='replace')))
        return

    logger.info('Custom script: %s finished successfully; output: \n%s' %
                (scriptName, output.decode('utf-8', errors='replace')))


def terminateProcess(process, scriptName):
    """Log that scriptName exceeded timeoutSeconds, then terminate process.

    Used as the callback of the watchdog timer started by executeScript().
    """
    warning = ('Custom script: %s taking longer than %s second[s]; terminating..'
               % (scriptName, timeoutSeconds))
    logger.warning(warning)
    process.terminate()


def getCustomScriptsFromDirectory():
    """Return the custom script names that apply to the current action.

    A file in customDir is selected when its name starts with
    "<action>_" or "all_". The result is sorted by the part of the name that
    follows the first underscore, so "all_" and per-action scripts are
    interleaved by that suffix rather than grouped by prefix.

    NOTE: `action` is read from module scope; it is assigned in the
    __main__ section, so this function only works after that has run.
    """
    # str.startswith accepts a tuple of alternatives; both prefixes already
    # contain the underscore, so no separate "'_' in name" test is needed.
    prefixes = (action + '_', 'all_')
    selected = [fileName for fileName in os.listdir(customDir)
                if fileName.startswith(prefixes)]
    return sorted(selected, key=lambda fileName: substringAfter(fileName, '_'))


def substringAfter(s, delimiter):
    """Return what follows the first delimiter in s, or "" if it is absent."""
    _head, _sep, tail = s.partition(delimiter)
    return tail


if __name__ == '__main__':
    # The hook is expected to be called with exactly four arguments (see
    # https://libvirt.org/hooks.html#qemu); the second is the operation and
    # the third its sub-operation. Any other argument count is ignored and
    # reported as success.
    if len(sys.argv) != 5:
        sys.exit(0)

    logger.debug("Executing qemu hook with args: %s" % sys.argv)

    # `action` must stay a module-level name: getCustomScriptsFromDirectory()
    # reads it as a global.
    action = sys.argv[2]
    status = sys.argv[3]

    # Unknown operations are logged but still exit with success.
    if action not in validQemuActions:
        logger.error('The given action: %s, is not a valid libvirt qemu operation.' % action)
        sys.exit(0)

    # Incoming migration: filter the domain XML from stdin to stdout.
    if (action, status) == ("migrate", "begin"):
        handleMigrateBegin()

    # Custom scripts receive the same arguments this hook was given.
    executeCustomScripts(sys.argv[1:])
